package com.wux.labs.spring.springbucks.filter

import com.alibaba.druid.filter.FilterEventAdapter
import com.alibaba.druid.proxy.jdbc.StatementProxy
import com.alibaba.druid.sql.dialect.h2.parser.H2StatementParser
import com.alibaba.druid.sql.parser.Token
import com.wux.labs.spring.springbucks.exception.InterceptException
import java.util.*

/**
 * Custom Druid SQL interceptor that validates every statement before it is executed, without
 * requiring any change to the calling code.
 *
 * Both rules inspect the token stream produced by Druid's H2 lexer (tokens, not raw text):
 * 1. **No statement concatenation**: the SQL may contain at most one statement. Leading and
 *    trailing semicolons are tolerated. Any token that follows a `;`, where that `;` itself
 *    follows a statement token, is rejected.
 * 2. **Bounded IN lists**: an `IN (...)` list may hold at most [maxInListSize] elements. Elements
 *    are counted as commas + 1 at the parenthesis level opened directly after `IN`. Commas inside
 *    nested parentheses (e.g. function calls) are not counted.
 *
 * Violations are reported by throwing [InterceptException].
 *
 * @property maxInListSize largest accepted IN-list size; defaults to [DEFAULT_MAX_IN_LIST_SIZE].
 *   Because every constructor parameter has a default value, a public no-arg constructor is still
 *   generated on the JVM, so reflective instantiation keeps working.
 */
class InterceptFilter(private val maxInListSize: Int = DEFAULT_MAX_IN_LIST_SIZE) : FilterEventAdapter() {

    init {
        require(maxInListSize > 0) { "maxInListSize must be positive, was $maxInListSize" }
    }

    /**
     * Validates [sql] and then delegates to the default adapter behaviour.
     *
     * @throws InterceptException if [sql] contains more than one statement or an IN list with
     *   more than [maxInListSize] elements.
     */
    override fun statementExecuteBefore(statement: StatementProxy, sql: String) {
        // Cheap textual pre-checks so the lexer only runs when a rule could possibly apply.
        if (sql.contains(';')) {
            checkSingleStatement(sql)
        }
        // This pre-check must match every way the IN keyword can be written. Matching only " in"
        // would miss IN preceded by a newline, a tab or ')', and oversized lists would slip through.
        if (sql.contains("in", ignoreCase = true)) {
            checkInListSize(sql)
        }
        super.statementExecuteBefore(statement, sql)
    }

    /** Throws [InterceptException] if [sql] holds a second statement after a `;` separator. */
    private fun checkSingleStatement(sql: String) {
        val lexer = H2StatementParser(sql).lexer
        var sawStatementToken = false // a non-`;` token has been seen
        var separatorPending = false // a `;` has appeared after a statement token
        var token = lexer.token()
        while (token != Token.EOF) {
            if (token == Token.SEMI) {
                // Leading semicolons separate nothing; a run of consecutive `;` counts once.
                if (sawStatementToken) separatorPending = true
            } else {
                // A real token after a separator starts a second statement. Trailing `;` never
                // reaches this branch and is therefore accepted.
                if (separatorPending) throw InterceptException("不允许进行SQL拼接")
                sawStatementToken = true
            }
            lexer.nextToken()
            token = lexer.token()
        }
    }

    /** Throws [InterceptException] if any `IN (...)` list in [sql] has more than [maxInListSize] elements. */
    private fun checkInListSize(sql: String) {
        val lexer = H2StatementParser(sql).lexer
        // Currently open parentheses, innermost last.
        val openParens = mutableListOf<ParenScope>()
        var prevToken: Token? = null
        var token = lexer.token()
        while (token != Token.EOF) {
            when (token) {
                Token.LPAREN -> openParens.add(ParenScope(isInList = prevToken == Token.IN))
                // Ignore an unmatched ')' instead of crashing; malformed SQL is left for the
                // database to reject.
                Token.RPAREN -> if (openParens.isNotEmpty()) openParens.removeAt(openParens.lastIndex)
                Token.COMMA -> {
                    val scope = openParens.lastOrNull()
                    if (scope != null && scope.isInList) {
                        scope.commas++
                        // n commas separate n + 1 elements; fail fast instead of lexing the rest.
                        if (scope.commas + 1 > maxInListSize) {
                            throw InterceptException("IN列表不能超过${maxInListSize}个")
                        }
                    }
                }
                else -> Unit
            }
            prevToken = token
            lexer.nextToken()
            token = lexer.token()
        }
    }

    /** One open parenthesis. [isInList] is true when it was opened directly after `IN`. */
    private class ParenScope(val isInList: Boolean) {
        /** Number of top-level commas seen so far inside this parenthesis. */
        var commas = 0
    }

    companion object {
        /** Default upper bound on the number of elements in an `IN (...)` list. */
        const val DEFAULT_MAX_IN_LIST_SIZE = 10
    }
}